{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ae844070",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:42:16.395269Z",
     "start_time": "2022-10-21T06:42:16.363355Z"
    }
   },
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import dltools"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6097495b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:42:28.907577Z",
     "start_time": "2022-10-21T06:42:21.264876Z"
    }
   },
   "source": [
    "# Settings handed to the data loader below.\n",
    "batch_size = 1\n",
    "max_len = 64\n",
    "# The loader's two results are used later as the batch iterator and the vocabulary.\n",
    "train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ba73a340",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:42:34.413510Z",
     "start_time": "2022-10-21T06:42:34.400544Z"
    }
   },
   "source": [
    "# Peek at the first mini-batch. Each batch holds, in order:\n",
    "# tokens, segments, valid_lens, pred_positions, mlm_weights, mlm, nsp\n",
    "# next(iter(...)) fetches exactly one batch and fails right here with\n",
    "# StopIteration if the iterator is empty, instead of leaving `i` undefined.\n",
    "i = next(iter(train_iter))\n",
    "i"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "aea978a1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T05:55:14.310132Z",
     "start_time": "2022-10-21T05:55:14.248790Z"
    }
   },
   "source": [
    "# BERT model configured with 2 layers, 2 attention heads and 128 hidden units.\n",
    "net = dltools.BERTModel(\n",
    "    len(vocab),\n",
    "    num_hiddens=128,\n",
    "    norm_shape=[128],\n",
    "    ffn_num_input=128,\n",
    "    ffn_num_hiddens=256,\n",
    "    num_heads=2,\n",
    "    num_layers=2,\n",
    "    dropout=0.2,\n",
    "    key_size=128,\n",
    "    query_size=128,\n",
    "    value_size=128,\n",
    "    hid_in_features=128,\n",
    "    mlm_in_features=128,\n",
    "    nsp_in_features=128,\n",
    ")\n",
    "# Devices to train on, and the criterion shared by both pretraining tasks.\n",
    "devices = dltools.try_all_gpus()\n",
    "loss = nn.CrossEntropyLoss()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1e40dd1e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T05:55:14.326089Z",
     "start_time": "2022-10-21T05:55:14.312126Z"
    }
   },
   "source": [
    "#@save\n",
    "def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,segments_X, valid_lens_x,pred_positions_X, mlm_weights_X,mlm_Y, nsp_y):\n",
    "    # 前向传播\n",
    "    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,valid_lens_x.reshape(-1),pred_positions_X)\n",
    "    # 计算遮蔽语言模型损失\n",
    "    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)\n",
    "    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n",
    "    # 计算下一句子预测任务的损失\n",
    "    nsp_l = loss(nsp_Y_hat, nsp_y)\n",
    "    l = mlm_l + nsp_l\n",
    "    return mlm_l, nsp_l, l"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b6f2e052",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T05:55:14.342046Z",
     "start_time": "2022-10-21T05:55:14.328084Z"
    }
   },
   "source": [
    "def train_bert(train_iter, net, loss, vocab_size, devices, num_steps, lr=0.01):\n",
    "    \"\"\"Pretrain BERT for exactly num_steps optimizer updates.\n",
    "\n",
    "    Args:\n",
    "        train_iter: Iterable that can be traversed repeatedly and yields\n",
    "            7-tuples (tokens_X, segments_X, valid_lens_x, pred_positions_X,\n",
    "            mlm_weights_X, mlm_Y, nsp_y) of tensors.\n",
    "        net: Model to train; it is wrapped in nn.DataParallel over devices.\n",
    "        loss: Criterion forwarded to _get_batch_loss_bert.\n",
    "        vocab_size: Vocabulary size forwarded to _get_batch_loss_bert.\n",
    "        devices: List of devices; every batch is moved to devices[0].\n",
    "        num_steps: Number of updates to run; must be positive.\n",
    "        lr: Learning rate for Adam (0.01 was the previously hard-coded value).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If num_steps is not positive or train_iter yields no batches.\n",
    "    \"\"\"\n",
    "    if num_steps <= 0:\n",
    "        # Otherwise the averages printed at the end would divide by zero.\n",
    "        raise ValueError(f'num_steps must be positive, got {num_steps}')\n",
    "    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
    "    trainer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    step, timer = 0, dltools.Timer()\n",
    "    animator = dltools.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])\n",
    "    # Sum of MLM losses, sum of NSP losses, number of sentence pairs, step count\n",
    "    metric = dltools.Accumulator(4)\n",
    "    # Keep passing over the data until num_steps updates have been made.\n",
    "    while step < num_steps:\n",
    "        steps_before_pass = step\n",
    "        for (tokens_X, segments_X, valid_lens_x, pred_positions_X,\n",
    "             mlm_weights_X, mlm_Y, nsp_y) in train_iter:\n",
    "            tokens_X = tokens_X.to(devices[0])\n",
    "            segments_X = segments_X.to(devices[0])\n",
    "            valid_lens_x = valid_lens_x.to(devices[0])\n",
    "            pred_positions_X = pred_positions_X.to(devices[0])\n",
    "            mlm_weights_X = mlm_weights_X.to(devices[0])\n",
    "            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])\n",
    "            trainer.zero_grad()\n",
    "            timer.start()\n",
    "            mlm_l, nsp_l, l = _get_batch_loss_bert(\n",
    "                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,\n",
    "                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)\n",
    "            l.backward()\n",
    "            trainer.step()\n",
    "            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)\n",
    "            timer.stop()\n",
    "            # Plot the running average of both losses.\n",
    "            animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3]))\n",
    "            step += 1\n",
    "            if step == num_steps:\n",
    "                break\n",
    "        if step == steps_before_pass:\n",
    "            # A pass that made no progress means the iterator is empty;\n",
    "            # without this guard the while loop would never terminate.\n",
    "            raise ValueError('train_iter yielded no batches')\n",
    "\n",
    "    print(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')\n",
    "    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2c0117d2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:50:36.775486Z",
     "start_time": "2022-10-21T06:49:30.089546Z"
    }
   },
   "source": [
    "# Run 500 pretraining steps on the configured devices.\n",
    "train_bert(\n",
    "    train_iter,\n",
    "    net,\n",
    "    loss,\n",
    "    vocab_size=len(vocab),\n",
    "    devices=devices,\n",
    "    num_steps=500,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "df3cd3aa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:51:37.519398Z",
     "start_time": "2022-10-21T06:51:37.504420Z"
    }
   },
   "source": [
    "def get_bert_encoding(net, tokens_a, tokens_b=None):\n",
    "    \"\"\"Return the encoder output of net for one sentence or a sentence pair.\n",
    "\n",
    "    Args:\n",
    "        net: Model called as net(token_ids, segments, valid_len); it must be\n",
    "            an nn.Module with at least one parameter.\n",
    "        tokens_a: List of tokens of the first sentence.\n",
    "        tokens_b: Optional list of tokens of the second sentence.\n",
    "\n",
    "    Returns:\n",
    "        The first output of net for a batch holding this single input.\n",
    "\n",
    "    Note:\n",
    "        Token ids are looked up in the notebook-level vocab.\n",
    "    \"\"\"\n",
    "    tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)\n",
    "    # Build the inputs on the device the model actually lives on, so the call\n",
    "    # works whether or not the model has been moved by training.\n",
    "    device = next(net.parameters()).device\n",
    "    token_ids = torch.tensor(vocab[tokens], device=device).unsqueeze(0)\n",
    "    segments = torch.tensor(segments, device=device).unsqueeze(0)\n",
    "    valid_len = torch.tensor(len(tokens), device=device).unsqueeze(0)\n",
    "    # Encode in eval mode so dropout does not randomize the representation,\n",
    "    # then put the model back into whatever mode it was in.\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        encoded_X, _, _ = net(token_ids, segments, valid_len)\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "    return encoded_X"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "17204e56",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:52:17.859385Z",
     "start_time": "2022-10-21T06:52:17.832457Z"
    }
   },
   "source": [
    "tokens_a = ['a', 'crane', 'is', 'flying']\n",
    "encoded_text = get_bert_encoding(net, tokens_a)\n",
    "# Tokens: '<cls>','a','crane','is','flying','<sep>'\n",
    "# so position 0 is '<cls>' and position 2 is 'crane'.\n",
    "encoded_text_cls, encoded_text_crane = encoded_text[:, 0, :], encoded_text[:, 2, :]\n",
    "(encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0, :3])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6230f098",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T05:55:22.422592Z",
     "start_time": "2022-10-21T05:55:22.380705Z"
    }
   },
   "source": [
    "# Show the whole encoding taken at position 2 ('crane') of the single-sentence input.\n",
    "encoded_text_crane"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f1dc03b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-10-21T06:54:10.087474Z",
     "start_time": "2022-10-21T06:54:10.064535Z"
    }
   },
   "source": [
    "tokens_a = ['a', 'crane', 'driver', 'came']\n",
    "tokens_b = ['he', 'just', 'left']\n",
    "encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)\n",
    "# Tokens: '<cls>','a','crane','driver','came','<sep>','he','just',\n",
    "# 'left','<sep>'\n",
    "# so position 0 is '<cls>' and position 2 is 'crane'.\n",
    "encoded_pair_cls, encoded_pair_crane = encoded_pair[:, 0, :], encoded_pair[:, 2, :]\n",
    "(encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0, :3])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e636b1d2",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
